{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

{%- import "../util/util.jinja" as util with context -%}

#pragma once

#include <Eigen/Core>
#include <sym/pose3.h>
#include <sym/camera.h>

namespace sym {

/**
{% for line in doc.cls.split('\n') %}
 *{{ ' {}'.format(line).rstrip() }}
{% endfor %}
 */
template <typename CameraCalType>
class PosedCamera : public Camera<CameraCalType> {
 public:
  using Scalar = typename CameraCalType::Scalar;

  /**
   * Construct a camera from a pose and a calibration.
   *
   * `pose` is copied into this object; `calibration` and `image_size` are forwarded unchanged
   * to the Camera<CameraCalType> base class. `image_size` defaults to (-1, -1).
   * NOTE(review): (-1, -1) presumably means "image size not specified" - confirm against Camera.
   */
  PosedCamera(const sym::Pose3<Scalar>& pose, const CameraCalType& calibration, const Eigen::Vector2i& image_size=Eigen::Vector2i(-1, -1))
    : Camera<CameraCalType>(calibration, image_size),
      pose_(pose) {}
  {{ util.print_docstring(doc.pixel_from_global_point) | indent(2) }}
  Eigen::Matrix<Scalar, 2, 1> PixelFromGlobalPoint(const Eigen::Matrix<Scalar, 3, 1>& point, const Scalar epsilon, Scalar* const is_valid) const {
    // Express the global point in this camera's frame by applying pose_ in inverse, then let
    // the base class project it. `is_valid` is handed to the base class as-is.
    const Eigen::Matrix<Scalar, 3, 1> camera_point = pose_.InverseCompose(point);
    return Camera<CameraCalType>::PixelFromCameraPoint(camera_point, epsilon, is_valid);
  }
  {{ util.print_docstring(doc.global_point_from_pixel) | indent(2) }}
  Eigen::Matrix<Scalar, 3, 1> GlobalPointFromPixel(const Eigen::Matrix<Scalar, 2, 1>& pixel, Scalar range_to_point, const Scalar epsilon, Scalar* const is_valid) const {
    // Back-project the pixel to a ray in the camera frame; `is_valid` is handed to the base
    // class as-is.
    const Eigen::Matrix<Scalar, 3, 1> camera_ray = Camera<CameraCalType>::CameraRayFromPixel(pixel, epsilon, is_valid);
    // The ray is normalized before scaling, so `range_to_point` is the Euclidean distance of
    // the returned point from the camera origin along that ray.
    const Eigen::Matrix<Scalar, 3, 1> camera_point = range_to_point * camera_ray.normalized();
    // Map the camera-frame point out through pose_.
    return pose_ * camera_point;
  }

  {{ util.print_docstring(doc.warp_pixel) | indent(2) }}
  Eigen::Matrix<Scalar, 2, 1> WarpPixel(const Eigen::Matrix<Scalar, 2, 1>& pixel, Scalar inverse_range, const PosedCamera& target_cam, const Scalar epsilon, Scalar* const is_valid) const {
    // See warp_pixel() in posed_camera.py.

    // Unit-length ray through `pixel` in this camera's frame. Validity is captured in a local
    // so it can be combined with the projection validity below.
    Scalar is_valid_ray;
    const Eigen::Matrix<Scalar, 3, 1> camera_ray = Camera<CameraCalType>::CameraRayFromPixel(pixel, epsilon, &is_valid_ray);
    const Eigen::Matrix<Scalar, 3, 1> camera_point = camera_ray.normalized();

    // Relative transform taking points from this camera's frame to the target camera's frame.
    // Held by value: the product is a temporary, so a reference would only obscure ownership.
    const sym::Pose3<Scalar> target_T_self = target_cam.Pose().Inverse() * Pose();

    // With p = unit_ray / inverse_range, the transformed point is R * p + t. What is computed
    // here is that quantity multiplied by inverse_range, i.e. R * unit_ray + t * inverse_range,
    // which avoids dividing by inverse_range (so inverse_range == 0 stays finite).
    // NOTE(review): this relies on the projection being unaffected by that scaling - see
    // warp_pixel() in posed_camera.py.
    const Eigen::Matrix<Scalar, 3, 1> transformed_point = target_T_self.Rotation() * camera_point + (target_T_self.Position() * inverse_range);

    Scalar is_valid_projection;
    const Eigen::Matrix<Scalar, 2, 1> target_pixel = target_cam.PixelFromCameraPoint(transformed_point, epsilon, &is_valid_projection);

    // The result is valid only if both the back-projection and the re-projection are valid.
    // Guard the write so callers that do not need the flag can pass nullptr without invoking
    // undefined behaviour.
    if (is_valid != nullptr) {
      *is_valid = is_valid_ray * is_valid_projection;
    }
    return target_pixel;
  }

  /**
   * Read-only access to the pose this camera was constructed with.
   *
   * The returned reference refers to a member of this object and must not outlive it.
   */
  const sym::Pose3<Scalar>& Pose() const {
    return pose_;
  }

 private:
  // Pose supplied at construction. Camera-frame points are mapped out through `pose_ * p`
  // and incoming points are mapped in through `pose_.InverseCompose(p)`.
  sym::Pose3<Scalar> pose_;
};

}  // namespace sym
